import java.util.*;

 /** Node of a singly linked list that stores one {@code int} value. */
class ListNode {
    /** Value carried by this node. */
    int val;

    /** Following node in the chain; {@code null} marks the end of the list. */
    ListNode next;

    /**
     * Builds a standalone node (no successor) holding the given value.
     *
     * @param val the value to store in the node
     */
    ListNode(int val) {
        this.val = val;
    }
}
/** Rearranges a singly linked list of {@link ListNode}s around a pivot value. */
public class Partition {

    /**
     * Partitions the list so that every node whose value is less than {@code x}
     * comes before every node whose value is greater than or equal to {@code x}.
     * The relative order of nodes inside each group is preserved. The existing
     * nodes are relinked; apart from two temporary sentinels, no nodes are allocated.
     *
     * @param pHead head of the list to partition; may be {@code null}
     * @param x     pivot value
     * @return head of the partitioned list, or {@code null} when {@code pHead} is {@code null}
     */
    public ListNode partition(ListNode pHead, int x) {
        // Sentinel heads avoid special-casing the first node appended to each sublist.
        ListNode lessDummy = new ListNode(0);
        ListNode geDummy = new ListNode(0);
        ListNode lessTail = lessDummy;
        ListNode geTail = geDummy;

        // cur.next is never modified inside the loop body (only the tails' next
        // pointers are), so advancing via cur.next is safe while relinking.
        for (ListNode cur = pHead; cur != null; cur = cur.next) {
            if (cur.val < x) {
                lessTail.next = cur;
                lessTail = cur;
            } else {
                geTail.next = cur;
                geTail = cur;
            }
        }

        // The last ">= x" node may still point at a "< x" node; cut it to avoid a cycle.
        geTail.next = null;
        // Append the ">= x" sublist after the "< x" sublist.
        lessTail.next = geDummy.next;
        return lessDummy.next;
    }

    public static void main(String[] args) {

    }
}
